'''
 Authors: Bowen Wen
 Contact: wenbowenxjtu@gmail.com
 Created in 2018

 Copyright (c) Bowen Wen, Rutgers University, 2018 
'''


'''
Evaluate crater detections against the ground truth (precision, recall, F1) and tally detections per land-classification type.
'''



import numpy as np
import scipy.io
import cv2
import os,sys
import functions as fc
import math


def land_classify_id(coord, ld_classify_v, ld_classify_u):
    """Return the land-type index of the patch located at ``coord``.

    ``ld_classify_v`` and ``ld_classify_u`` are parallel coordinate arrays
    (assumed to have the same shape); the first-axis (row) index of a
    matching entry is the land type. When several entries match, the first
    one in row-major order wins.

    Args:
        coord: ``(v, u)`` pair to look up.
        ld_classify_v: array of v coordinates.
        ld_classify_u: array of u coordinates, same shape as ``ld_classify_v``.

    Returns:
        int: row index of the first entry whose v and u both equal ``coord``.

    Raises:
        ValueError: if no entry matches. Returning None instead would be
            unsafe, because callers use the result as an array index and
            ``arr[None] += 1`` silently touches every element.
    """
    v, u = coord
    # Match both coordinates at once. np.nonzero yields indices in row-major
    # order, so rows[0] is the same entry a nested element-by-element scan
    # would find first.
    rows = np.nonzero((ld_classify_v == v) & (ld_classify_u == u))[0]
    if rows.size == 0:
        raise ValueError('id not found error! coord={}'.format(coord))
    return int(rows[0])

    



# read data
# Land-classification result image; detections are drawn onto a copy of it.
test_img = cv2.imread('../../landClassification/ReturnResult/Img1_classification_resutl.jpg')
if test_img is None:
    # cv2.imread reports a missing/unreadable file by returning None, which
    # would otherwise surface as an opaque AttributeError on .copy().
    sys.exit('could not read classification image')
img = test_img.copy()

# Test-region rectangle; used below as (u_min, u_max, v_min, v_max).
corner = fc.get_corner(test_region=True)
cv2.rectangle(img, (corner[0], corner[2]), (corner[1], corner[3]), (0, 0, 0), 15)

# Ground-truth crater centres; column 0 is read as u and column 1 as v below.
groundtruth = np.array(scipy.io.loadmat('../../data/groundtruth(u_v).mat')['x'], dtype=float)
print('corner', corner)
print('groundtruth shape:', groundtruth.shape)

## Land-classification results. The row index of ld_classify_v / ld_classify_u
## is the land type (see land_classify_id), 6 classes in this order:
## [red (trees), green (cultivated agricultural land), pink (cloud),
##  yellow (uncultivated agricultural land), blue (building/infrastructure),
##  black (water)]
# Feature pipeline whose classification results are evaluated: 'vgg', 'sift' or 'bow'.
LD_CLASSIFY_FEATURE = 'vgg'
_ld_stat_dir = '../../landClassification/stat_result/'
ld_classify_v = scipy.io.loadmat(_ld_stat_dir + 'x_{}.mat'.format(LD_CLASSIFY_FEATURE))['x']
ld_classify_u = scipy.io.loadmat(_ld_stat_dir + 'y_{}.mat'.format(LD_CLASSIFY_FEATURE))['y']
# Shift by half a 64-px patch. Presumably the stored coordinates are patch
# centres and the detection coords they are compared with are top-left
# corners -- confirm against the classification code.
ld_classify_v -= 32
ld_classify_u -= 32
# np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin int
# selects the same default integer dtype.
ld_classify_v = np.array(ld_classify_v, dtype=int)
ld_classify_u = np.array(ld_classify_u, dtype=int)

NUM_LAND_TYPES = 6
land_total = np.zeros(NUM_LAND_TYPES)       # detections per land type
land_truedetect = np.zeros(NUM_LAND_TYPES)  # true positives per land type
# Hard-coded reference counts per land type (presumably ground-truth craters
# in the test region -- confirm).
land_groundtruth = np.array([4, 143, 1, 23, 5, 1])


# Directory of detected patches, one image per detection. Alternatives for
# other feature pipelines:
#   '../segmented_classify/RF/output_region_VGG'
#   '../segmented_classify/RF/output_region_bow'
#   '../hog_color_pca_results/posStore/'
img_path = '../segmented_classify/RF/output_region_central_sift'

img_names = os.listdir(img_path)
counter = 0          # detections visited, including ones outside the test region
true_pos = 0
false_pos = 0
total_detection = 0  # detections inside the test region, duplicates excluded
total_craters = groundtruth.shape[0]
groundtruth_flag = np.zeros(groundtruth.shape[0])  # set to 1 once a crater is matched

# Accept any case / surrounding whitespace and re-ask until the answer is
# usable; otherwise no matching branch runs and the metrics divide by zero.
while True:
    remove_overlap_ = input('do you want to remove overlap?  (y/n)').strip().lower()
    if remove_overlap_ in ('y', 'n'):
        break
    print("please answer 'y' or 'n'")

# Detection patches are 64x64; file names / coords give the top-left corner.
HALF_PATCH = 32
# A detection within this many pixels of its nearest ground-truth crater is a hit.
MATCH_RADIUS = 30


def _in_test_region(u, v):
    """Return True if (u, v) lies inside ``corner`` (bounds inclusive)."""
    return corner[0] <= u <= corner[1] and corner[2] <= v <= corner[3]


def _draw_box(u, v, color):
    """Draw a 64x64 detection box centred at (u, v) on ``img`` (BGR color)."""
    cv2.rectangle(img, (u - HALF_PATCH, v - HALF_PATCH),
                  (u + HALF_PATCH, v + HALF_PATCH), color, 10)


def _match_detection(u, v):
    """Score one in-region detection centred at (u, v) against the ground truth.

    Returns:
        'tp'  -- the nearest crater is within MATCH_RADIUS and not yet matched;
                 it is flagged in ``groundtruth_flag`` and a red box is drawn.
        'dup' -- the nearest crater is within range but already matched;
                 nothing is drawn and the detection is not counted.
        'fp'  -- no crater within range (or no ground truth); a blue box is drawn.
    """
    if groundtruth.shape[0] > 0:
        dists = np.hypot(groundtruth[:, 0] - u, groundtruth[:, 1] - v)
        nearest = int(np.argmin(dists))  # first minimum wins on ties
        if dists[nearest] <= MATCH_RADIUS:
            if groundtruth_flag[nearest] != 0:
                return 'dup'
            groundtruth_flag[nearest] = 1
            _draw_box(u, v, (0, 0, 255))
            return 'tp'
    _draw_box(u, v, (255, 0, 0))
    return 'fp'


# TODO: land-classification statistics are only collected on the 'y'
# (remove overlap) path; the 'n' path does not compute them yet.
if remove_overlap_ == 'n':
    for name in img_names:
        counter += 1
        # File name is "<v>_<u>.<ext>" (top-left corner); shift to the centre.
        parts = os.path.splitext(name)[0].split('_')
        u = int(parts[1]) + HALF_PATCH
        v = int(parts[0]) + HALF_PATCH
        if not _in_test_region(u, v):
            continue
        result = _match_detection(u, v)
        if result == 'dup':
            continue  # repeated hit on an already-matched crater: ignore
        total_detection += 1
        if result == 'tp':
            true_pos += 1
        else:
            false_pos += 1
elif remove_overlap_ == 'y':
    coords = fc.remove_overlap(img_path, 8)
    print('overlap removed... now shape = ', coords.shape)
    if len(coords) == 0:
        print('coords wrong \n\n')
        sys.exit(1)

    for cnt, coord in enumerate(coords, start=1):
        print('processing {}/{}'.format(cnt, len(coords)))
        landtype = land_classify_id(coord, ld_classify_v, ld_classify_u)
        if landtype is None:
            # Indexing with None would broadcast and bump every land type.
            sys.exit('land type not found for coord {}'.format(coord))
        # NOTE(review): land_total is incremented before the test-region filter
        # and for duplicate hits, unlike total_detection -- confirm intended.
        land_total[landtype] += 1
        counter += 1
        # coord is (v, u) of the top-left corner; shift to the centre.
        u = coord[1] + HALF_PATCH
        v = coord[0] + HALF_PATCH
        if not _in_test_region(u, v):
            continue
        result = _match_detection(u, v)
        if result == 'dup':
            continue  # repeated hit on an already-matched crater: ignore
        total_detection += 1
        if result == 'tp':
            true_pos += 1
            land_truedetect[landtype] += 1
        else:
            false_pos += 1
else:
    sys.exit("unrecognised answer {!r}: expected 'y' or 'n'".format(remove_overlap_))


# Crop the annotated image to the test region (rows are v, columns are u).
img_test = img[corner[2]:corner[3], corner[0]:corner[1]]
# NOTE(review): machine-specific output path; consider making it configurable.
OUTPUT_IMG_PATH = 'C:\\Users\\hasee\\Desktop\\rectangle.png'
if not cv2.imwrite(OUTPUT_IMG_PATH, img_test):
    # cv2.imwrite signals failure (e.g. missing directory) by returning False.
    print('warning: could not write', OUTPUT_IMG_PATH)


def _safe_div(num, den):
    """Return num / den as a float, or 0.0 when den is zero."""
    return num / float(den) if den else 0.0


precision = _safe_div(true_pos, total_detection)
recall = _safe_div(true_pos, total_craters)
# Harmonic mean written as 2PR/(P+R): the form 2/(1/P + 1/R) divides by zero
# whenever true_pos == 0.
f1 = _safe_div(2 * precision * recall, precision + recall)
print('total_detection={}, total_caters={}, true_pos={}'.format(total_detection, total_craters, true_pos))
print('precision={}, recall={}, f1={}'.format(precision, recall, f1))

print('land classification info:')
print('land_truedetect:', land_truedetect)
print('land_total:', land_total)
print('land_groundtruth:', land_groundtruth)
